rm(list=ls())
library(doParallel)
library(doRNG)
library(foreach)
# Execution mode. TRUE: the number of worker cores is read from the first
# trailing command-line argument. FALSE: local run that derives the core
# count from the machine and switches to the local project directory.
cluster <- TRUE
if (isTRUE(cluster)) {
    args <- commandArgs(TRUE)
    print(args)
    # Fail with a clear message instead of a subscript error / NA core count.
    if (length(args) < 1) {
        stop("Expected the number of cores as the first command-line argument.",
             call. = FALSE)
    }
    no_cores <- suppressWarnings(as.numeric(args[[1]]))
    if (is.na(no_cores) || no_cores < 1) {
        stop("The number of cores must be a positive number, got: ",
             args[[1]], call. = FALSE)
    }
} else {
    # Leave two cores free, but never request fewer than one worker
    # (detectCores() can return NA or a value <= 2).
    no_cores <- max(1, detectCores(all.tests = TRUE) - 2, na.rm = TRUE)
    cluster.path <- "~/Desktop/causal res/SNMM_Local/3. Data and Programming/cluster/080820"
    setwd(cluster.path)
}

# Register the parallel backend used by the %dopar% loop further down.
registerDoParallel(cores = no_cores)
print(no_cores)

library(abind)
{
    # Install stop() as the `error` option, then load the project helper
    # scripts one after another in the listed order. source() is called with
    # its defaults, exactly as before.
    options(error=stop)
    invisible(lapply(
        c("2.0.0 MyFunc.R",
          "2.0.1 GetProb.R",
          "2.0.3 GetContrast.R",
          "2.1 DataGen.R",
          "2.1.2 MLEst3.R",
          "2.2.3 PMLEst4.R"),
        function(helper.file) source(helper.file)
    ))
}
    
{
    # Index of the seed block used by this run (simulation i uses entry
    # i + (SEED.index - 1) * numsims of the seed pool).
    SEED.index <- 1 # DEBUG

    ### Start of true values specification
    #   parameters
    n.cand      <- c(100, 200, 500, 5000)   # candidate sample sizes
    numsims     <- 500                      # number of simulated data sets
    #### Parameters that have stable properties
    # alpha.true  = t(matrix(rep(c(0,-0.5),5),2,5))
    # alpha.true  = t(matrix(rep(c(0,0.7),5),2,5)) #CHANGE
    # # alpha.true  = t(matrix(rep(c(-0.5,1),5),2,5)) #CHANGE
    # # beta.true   = t(matrix(rep(c(-0.5,0.1),5),2,5))
    # beta.3 =  c(-0.5, 1)
    # beta.1 = beta.2 = beta.4 = beta.5 =c(-0.5, 0.1);
    # beta.true   = t(matrix(c(beta.1, beta.2, beta.3, beta.4, beta.5),2,5))


    ### Parameters to illustrate g-null paradox
    # alpha.true and beta.true are 5 x 2 matrices (one row per beta.k pair).
    alpha.true  <- t(matrix(rep(c(0,0),5),2,5))
    beta.3 <- c(-0.5, 0.1)
    beta.1 <- c(-0.5, -1); beta.2 <- c(-0.5, 1.5)
    beta.4 <- c(-0.5, 1.5); beta.5 <- c(-0.5, -2)
    beta.true   <- t(matrix(c(beta.1, beta.2, beta.3, beta.4, beta.5),2,5))

    # datagen = DataGen(1000)
    # MyAttach(datagen)
    # cnt

    # 2 x 4 matrix; the first row is padded with NA in its last two entries.
    gamma.true  <- t(matrix(c(0.1,-0.5,NA,NA,0.1,-0.5,0.1,-0.5),4,2))
    est.method.cand <- c("MLE3","PMLE4","DR")
    # params.true is built before the row names are attached below, so its
    # copy of gamma.true carries no row names.
    params.true <- list(alpha.true=alpha.true, beta.true=beta.true,
                        gamma.true=gamma.true)
    rownames(gamma.true) <- c("a0","a1")

    ### To store results
    # 7-way array indexed by simulation, sample size, estimation method,
    # parameter row, parameter column, parameter block and measure.
    sim.name <- paste("sim=",seq_len(numsims),sep="")
    n.name   <- paste("n=",n.cand,sep="")
    est.method.name <- as.character(est.method.cand)
    results  <- array(dim=c(numsims,length(n.cand), length(est.method.cand),
                            nrow(alpha.true), ncol(alpha.true),2,4),
                      dimnames = list(sim.name,n.name,est.method.name,
                                      1:nrow(alpha.true), 1:ncol(alpha.true),
                                      c("alpha","beta"),c("est","nll","time",
                                                          "contrast")))
    names(dimnames(results)) <- c("Simulation number","Sample size","est.method",
                                  "1-5","1/2","Param","Measure")

    # Suppress messages (this is a global variable!).
    # The original `MESSAGE <-- FALSE` parsed as `MESSAGE <- (-FALSE)` and
    # stored the integer 0 rather than the logical FALSE.
    MESSAGE <- FALSE
}
    ptm0 <- proc.time()   # start of the whole run

    # for(n.index in 1:length(n.cand)){

        {
            # Only one sample size is run; the loop over all candidates is
            # commented out above.
            n.index <- 4

        ptm1 <- proc.time()   # start of this sample size
        n <- n.cand[n.index]

        cat("Sample size: ",n,"\n")
        # The pool of per-simulation seeds is a deterministic function of n.
        set.seed(n)
        seeds       <- runif(10000,-2^30,2^30)
        # Simulation i reads seeds[i + (SEED.index - 1) * numsims]; stop early
        # rather than handing NA to set.seed() when the pool is too short.
        stopifnot(SEED.index * numsims <= length(seeds))
        {
            ee.assume.true.alpha <- TRUE

            # Nuisance model for y.hat.type for DR estimation;
            y.hat.type <- "MLE" # "saturated" or "MLE" or "logit" or "diy"
            prop.score.type <- "correct" # "correct" or "incorrect"
            vectorize.get.prob <- FALSE; parallel.root <- FALSE
            l0.type <- "discrete" # "continuous" or "discrete"
            l1.type <- "correct" # "correct" or "incorrect" or "diy"
            gop.type <- "incorrect" # "correct" or "incorrect"
        }

        # Switches selecting which estimators / diagnostics run per simulation.
        {
            doMLE <- 0; doPMLE <- 1
            doDR <- 1
            doGComp <- 1
            doDebugEE <- 0
        }
        {
            source("2.5 CallDR.R")
            source("2.5.1 DR.R")
            source("2.5.1 DR_Iter.R")
            source("2.5.3 DR_MyFunc.R")
            source("2.6 DR_EE_value.R")
        }

        # Storage for the g-computation estimates (8 cells per simulation),
        # the per-simulation L0 summary and the estimating-equation debug values.
        GComp.res <- array(dim=c(numsims, 8))
        p.l0 <- array(dim=c(numsims))
        {ee.res <- array(dim=c(numsims,5*2))
        ee.stage1 <- ee.d1 <- ee.d0 <- array(dim=c(numsims,5*2))
        ee.u1 <- ee.u0 <- numeric(numsims)
        # j = 1; vals = list() # to output test value
        }
        # numsims = 500
        # The simulations are run in M batches of per.iter each; the batch
        # index arithmetic below only works when numsims is a multiple of M.
        M <- 5; per.iter <- numsims/M
        stopifnot(numsims %% M == 0)

        ### 25 Jul 2021: run 10 times DR estimation.
        # Same layout as one (sample size, method) slice of `results`, with a
        # leading dimension for the 10 starting values.
        dr.res.10 <- array(dim=c(10, dim(results[,n.index,"PMLE4",,,,])))
        dimnames(dr.res.10) <- list(Initial = 1:10, sim.name, 1:nrow(alpha.true), 1:ncol(alpha.true),
                                    c("alpha","beta"),c("est","nll","time",
                                                        "contrast"))
        names(dimnames(dr.res.10) )[2:6] <- c("Simulation Number", "1-5","1/2","Param","Measure")
    } 
    # Wall-clock reference for the whole set of simulation batches.
    t0.total = proc.time()
    # The simulations are processed in M consecutive batches of per.iter
    # simulations; an intermediate .RData file is saved at the end of each batch.
    for (iter in 1:M){
        # iter = 1
        # Simulation indices that belong to this batch.
        boot.list = ((iter-1)*per.iter+1):(iter*per.iter)
        # for(i in boot.list){
        # One parallel task per simulation index. Each task evaluates to a list
        # of its estimates; the collection loop after the foreach reads boot1
        # by position within boot.list.
        boot1 = foreach(i = boot.list,.errorhandling="stop") %dopar% {
            # i= 1 #DEBUG
            
            # Seed of this simulation: entry i of the SEED.index-th block of
            # numsims entries in `seeds` (seed.perturb shifts it; currently 0).
            seed.perturb = 0
            seed.index = i + (SEED.index-1) * numsims +seed.perturb
            set.seed(seeds[seed.index])
            cat("\t Simulation number: ",seed.index-seed.perturb,"\n")    
            
            ### 0: Data generation: x, y
            source("2.1 DataGen.R")

            # NOTE(review): MyAttach presumably makes the elements of `datagen`
            # (e.g. `data`, `cnt` and the *.cov / *.vec objects used below)
            # available as variables -- confirm in "2.0.0 MyFunc.R".
            datagen = DataGen(n, SEED=seeds[seed.index])
            MyAttach(datagen)
            
            ### 1 = Maximal likelihood estimate
            # library(profvis) # DEBUG
            # # library(bench)
            # vectorize.get.prob = T;parallel.root = F

            # profvis(MLEst3(data, cnt))

            # Full MLE; stays NULL when the doMLE switch is off.
            mle3 <- NULL
            if (doMLE) {
                set.seed(seeds[seed.index])
                mle.start <- proc.time()
                mle3 <- MLEst3(data, cnt)
                mle.elapsed <- proc.time() - mle.start
                # Third entry of the proc.time difference (403.982 seconds in an earlier run)
                print(paste0("MLE time: ", mle.elapsed[3], " secs"))
                results[i, n.index, "MLE3", , , , ] <- mle3
            }
            
            # PMLE fit; stays NULL when the doPMLE switch is off.
            pmle4 <- NULL
            source("2.0.1 GetProb.R")
            if (doPMLE) {
                set.seed(seeds[seed.index])
                pmle.start <- proc.time()
                pmle4 <- PMLEst4(data, cnt)
                results[i, n.index, "PMLE4", , , , ] <- pmle4
                # Third entry of the proc.time difference (360.650 in an earlier run)
                pmle.secs <- as.numeric(proc.time() - pmle.start)[3]
                print(paste0("PMLE time: ", round(pmle.secs, 2), " secs"))
            }
            # "est" measure of the alpha block (index 1) and the beta block (index 2).
            alpha.pmle <- pmle4[, , 1, 1]
            beta.pmle <- pmle4[, , 2, 1]
            
            
            ### 2 = DR Estimate based on Vansteelandt and Joffe (2014)
            #   2.1 Estimate gamma propensity score model
            # Coefficients of a logistic fit of `response` on the columns of
            # `design`, with no intercept added by glm.fit.
            logit.coef <- function(design, response) {
                glm.fit(design, response, family=binomial(), intercept=FALSE)$coefficients
            }
            if (prop.score.type == "correct") {
                gamma.a0 <- logit.coef(a0.cov, a0.vec)
                gamma.a1 <- logit.coef(a1.cov, a1.vec)
            } else {
                # Misspecified model: zero out every design column after the
                # first, fit, and force the matching coefficients to 0.
                a0.cov.incor <- a0.cov
                a1.cov.incor <- a1.cov
                a0.cov.incor[, 2] <- 0
                a1.cov.incor[, 2:ncol(a1.cov.incor)] <- 0
                gamma.a0 <- logit.coef(a0.cov.incor, a0.vec)
                gamma.a1 <- logit.coef(a1.cov.incor, a1.vec)
                gamma.a0[2] <- 0
                gamma.a1[2:ncol(a1.cov.incor)] <- 0
            }
            # Row "a0" is padded with two NA entries so that both rows have 4 columns.
            gamma <- matrix(c(gamma.a0, NA, NA, gamma.a1), nrow=2, ncol=4, byrow=TRUE)
            rownames(gamma) <- c("a0","a1")
            #   2.2 Estimate outcome regression for y
            # Re-load the DR helpers and shared utilities inside the worker.
            for (helper.file in c("2.5 CallDR.R",
                                  "2.5.1 DR.R",
                                  "2.5.1 DR_Iter.R",
                                  "2.5.3 DR_MyFunc.R",
                                  "2.0.0 MyFunc.R",
                                  "2.0.1 GetProb.R")) {
                source(helper.file)
            }
            
            
            
            
            
            # Doubly robust estimate; stays NULL when the doDR switch is off.
            dr.sol <- NULL
            if (doDR){
                # DREst() takes the PMLE estimates as inputs, so compute them
                # here when the PMLE step above was switched off.
                if (!doPMLE){
                    set.seed(seeds[seed.index])
                    pmle4 <- PMLEst4(data, cnt)
                    results[i,n.index,"PMLE4",,,,] <- pmle4
                    # pmle4 = results[i,n.index,"PMLE4",,,,] #DEBUG
                }

                # (alpha.pmle = mle3[,,1,1])
                # beta.pmle = mle3[,,2,1]
                alpha.pmle <- pmle4[,,1,1]
                beta.pmle <- pmle4[,,2,1]
                # alpha.pmle = alpha.true
                # beta.pmle = beta.true

                if (gop.type == "incorrect"){
                    # Replace data / cnt by the versions returned by rm.l0.tilde().
                    tmp <- rm.l0.tilde(data, cnt)
                    data <- tmp$data
                    cnt <- tmp$cnt
                    source("2.5.3 DR_MyFunc.R")
                    nb <- 32
                    vec.1 <- rep(1,nb)
                    vec.0 <- rep(0,nb)
                }

                # Nuisance outcome model used by the DR estimator.
                if (y.hat.type == "saturated"){
                    # Use saturated model
                    y.hat <- cnt[17:32]/(cnt[17:32]+cnt[1:16])
                }else if (y.hat.type == "MLE"){
                    # Use the invert map
                    DataGen.obj <- func.DataGen(c(alpha.pmle,beta.pmle,gamma.true), data,cnt)
                    y.hat <- DataGen.obj$p.y[1:16]
                }else if (y.hat.type == "logit"){
                    # Use logistic regression
                    y.hat <- glm.fit(y.cov,y.vec,family=binomial(),intercept=FALSE)$coeff
                }else if (y.hat.type == "diy"){
                    y.hat <- c(0.5,0,0,0,0)
                }

                # NOTE(review): l1.hat is only assigned here when
                # l1.type == "incorrect", yet it is always passed to DREst()
                # below -- confirm DREst() does not use it otherwise.
                if (l1.type == "incorrect"){ # then build a logit reg of l1~l0 + a0
                    l1.cov <- y.cov[,c("l0.1","l0.2","a0")]
                    l1.hat <- glm.fit(l1.cov,l1.vec,family=binomial(),intercept=FALSE)$coeff
                }

                #   2.3 Estimate alpha (beta and p0 are given by MLE)
                # j = 1; vals = list() # to output EE value
                set.seed(seeds[seed.index])
                t4 <- proc.time()
                results[i,n.index,"DR",,,,] <- dr.sol <-
                    DREst(data,cnt,alpha.pmle,beta.pmle,gamma, y.hat,l1.hat,
                          SEED = seeds[seed.index], dr.warm="MLE")
                t5 <- proc.time()
                # Report one number, as the MLE / PMLE steps do; pasting the
                # whole proc.time difference printed one string per entry.
                print(paste0("DR time: ", round(as.numeric(t5 - t4)[3],2), " secs"))
                # plot(unlist(vals),type='l',
                #      main="EE when alpha,beta use pmle, with saturated OR",
                #      xlab='optim iteration', ylab="EE value")
                # abline(1,0,col="red")

                alpha.dr <- dr.sol[, , 1, 1]
                beta.dr <- dr.sol[, , 2, 1]
                # (contrast.dr = dr.sol[, , 1, 4])
                source("2.6 DR_EE_value.R")

                # for(j in 1:10){
                #     # j=2
                #     dr.warm = dr.warm.obj = NULL
                #     if (j == 1){
                #         dr.warm="DIY"
                #         dr.warm.obj = alpha.pmle
                #     }else{
                #         dr.warm = "random"
                #     }
                #     dr.res.10[j,i,,,,]= dr.est =DREst(data,cnt,alpha.pmle,beta.pmle,gamma, y.hat,l1.hat, 
                #                                         SEED = seeds[seed.index]+j, dr.warm=dr.warm, dr.warm.obj=dr.warm.obj)
                #     
                #     (alpha.dr = dr.est[, , 1, 1])
                #     (beta.dr = dr.est[, , 2, 1])
                #     ee.val = EE.value(data, cnt, alpha.dr,beta.dr, gamma,
                #                        Hessian=FALSE,y.hat)
                #     dr.res.10[j,i,,,,"nll"] = nll.j = sum(ee.val$tmp^2)
                #     if (nll.j < 0.005){ # Here point.dr.est.MLE$nll is the DR objective value
                #         # Copy and paste this result and stop
                #         if (seed.start < 10){
                #             for (k in (j+1):10){
                #                 dr.res.10[k,i,,,,]= dr.res.10[j,i,,,,]
                #             }
                #         }
                #         break
                #     }
                # }
            }
            
            ############
            # Defaults taken from the storage arrays, so that the values
            # returned by this task are defined even when doGComp is off
            # (p.l0.i was previously only created inside the branch below).
            GComp.res.i <- GComp.res[i,]
            p.l0.i <- p.l0[i]
            if (doGComp){
                ### Do g-computation for E[Y(a1,a2)|L0=l0]
                # alpha.pmle = pmle4[,,1,1]
                beta.pmle <- pmle4[,,2,1]
                # The g-computation always uses the logistic outcome regression.
                y.hat.type <- "logit" # "saturated" or "MLE" or "logit" or "diy"
                if (y.hat.type == "logit"){
                    # Use logistic regression
                    y.hat <- glm.fit(y.cov,y.vec,family=binomial(),intercept=FALSE)$coeff
                }
                # Fitted outcome probability expit((l1, a0, l0, a1) %*% y.hat).
                # l0 is either one covariate vector or a matrix with one row
                # per observation. is.matrix() replaces the class() comparison,
                # which yields a length-2 condition for matrices in R >= 4.0.
                y.hat.func <- function(l0,a0,l1,a1,y.hat){
                    if (is.matrix(l0)){
                        res <- c(expit(cbind(l1,a0,l0,a1) %*% y.hat))
                    }else{
                        res <- expit(c(l1,a0,l0,a1) %*% y.hat)
                    }
                    return(res)
                }

                # ... and a logistic regression of l1 on l0 and a0.
                l1.type <- "incorrect"
                if (l1.type == "incorrect"){ # then build a logit reg of l1~l0 + a0
                    l1.cov <- y.cov[,c("l0.1","l0.2","a0")]
                    l1.hat <- glm.fit(l1.cov,l1.vec,family=binomial(),intercept=FALSE)$coeff
                }
                # Fitted P(l1 = 1) as expit((l0, a0) %*% l1.hat); `beta` and
                # `nb` are accepted for interface compatibility but not used.
                p.l1.func <- function(a0,l0,beta, nb,l1.hat="holder"){
                    if (is.matrix(l0)){
                        # res = sapply(1:nb, function(i) {expit(c(l0[i,], a0[i]) %*% l1.hat)})
                        res <- c(expit(cbind(l0, a0) %*% l1.hat))
                    }else{
                        res <- expit(c(l0, a0) %*% l1.hat)
                    }
                    return(res)
                }

                # One estimate per (l0, a0, a1) cell, stored at position
                # 4*l0 + 2*a0 + a1 + 1; the outcome model is averaged over l1.
                for (l0 in c(0,1)){
                    for (a0 in c(0,1)){
                        for (a1 in c(0,1)){
                            l0.vec <- c(1,l0)
                            cell <- 4*l0 + 2*a0 + a1 + 1
                            p.l1 <- p.l1.func(a0,l0.vec,beta.pmle, nb,l1.hat)
                            GComp.res[i,cell] <- GComp.res.i[cell] <-
                                y.hat.func(l0=l0.vec,a0,l1=1,a1,y.hat) * p.l1 +
                                y.hat.func(l0=l0.vec,a0,l1=0,a1,y.hat) * (1-p.l1)
                            # print(GComp.res.i[cell])
                        }
                    }
                } # End of l0 loop

                ##### 4 Aug 2021: Store prevalence of L0
                # What is stored is the sum of cnt over rows with l0.2 == 1.
                # The former trailing `/n` sat outside the assignment
                # parentheses, so its result was discarded; the stored value is
                # kept unchanged because the later marginalisation step does
                # its own rescaling of p.l0.
                p.l0.i <- p.l0[i] <- sum(cnt[which(data[,"l0.2"] == 1)])
            } # End of doGComp
            # Optional diagnostic: evaluate the estimating equations at the DR
            # solution and keep their components.
            # NOTE(review): ee.* are not part of the list returned by this
            # task, so these values look like they are lost once the worker
            # finishes -- confirm whether they should be returned as well.
            if (doDebugEE){
                source("2.6 DR_EE_value.R")
                # `y.hat` is the fitted outcome model; the previous `yhat` was
                # a typo for a name that is never defined in this script.
                ee.val <-
                    EE.value(data, cnt, alpha.dr,beta.dr, gamma,
                             Hessian=FALSE,y.hat)
                ee.u1[i] <- ee.val$tmp.u1
                ee.u0[i] <- ee.val$tmp.u0
                # Flatten the 5 x 2 components row by row: entry (r, c) goes to
                # column 2*(r-1) + c.
                for (r.idx in 1:5){
                    for (c.idx in 1:2){
                        col.idx <- 2*(r.idx-1) + c.idx
                        ee.res[i,col.idx] <- ee.val$tmp[r.idx, c.idx]
                        ee.stage1[i,col.idx] <- ee.val$tmp.stage1[r.idx, c.idx]
                        ee.d1[i,col.idx] <- ee.val$tmp.d1[r.idx, c.idx]
                        ee.d0[i,col.idx] <- ee.val$tmp.d0[r.idx, c.idx]
                    }
                }
            }
            # Value of this foreach task: the estimates of simulation i that
            # the collection loop below copies into the result arrays.
            list(mle3=mle3, pmle4=pmle4, dr.sol=dr.sol, GComp.res.i=GComp.res.i,
                 p.l0.i=p.l0.i)
        } # End of i = 1:numsims loop
        # Copy what each task of this batch returned into the result arrays;
        # the k-th element of boot1 belongs to simulation boot.list[k].
        for (i in boot.list) {
            batch.pos <- i - (iter - 1) * per.iter
            tmp <- boot1[[batch.pos]]
            if (doMLE) results[i, n.index, "MLE3", , , , ] <- tmp$mle3
            if (doPMLE) results[i, n.index, "PMLE4", , , , ] <- tmp$pmle4
            if (doDR) results[i, n.index, "DR", , , , ] <- tmp$dr.sol
            if (doGComp) {
                GComp.res[i, ] <- tmp$GComp.res.i
                p.l0[i] <- tmp$p.l0.i
            }
        }
        # Checkpoint after every batch, one file per value of iter.
        checkpoint.file <- paste0("./RDataFiles/11Jun2021-parallel-n-1000-cts-l0-theta-not-0-l1-logit-iter-",
                                  iter, ".RData")
        save(results, GComp.res, params.true, file=checkpoint.file)
    } # End iter
    # Elapsed time over all batches (difference of the two proc.time() readings).
    t1.total = proc.time()
    t1.total - t0.total
    #############
        # Unlist parallel result
        ##############
        
        
        
        # mean.res = results[1,n.index,"PMLE4",,,,]
        # n.index = 4
        
        
        
        
        # Number of simulations to average over.
        curr.sim <- numsims
        # Add up one slice per simulation, in increasing order of the
        # simulation index (slice 1 + slice 2, then + slice 3, ...).
        sum.over.sims <- function(slice.of) {
            Reduce(`+`, lapply(1:curr.sim, slice.of))
        }
        # Despite their names these objects hold sums; they are divided by
        # curr.sim only in the expressions below.
        mean.res <- sum.over.sims(function(s) GComp.res[s, ])
        mean.res.pmle <- sum.over.sims(function(s) results[s, n.index, "PMLE4", , , , ])
        mean.res.DR <- sum.over.sims(function(s) results[s, n.index, "DR", , , , ])

        # Average "est" measure of the alpha block for PMLE and DR, and the
        # average g-computation estimates.
        (mean.res.pmle/curr.sim)[, , 1, 1]
        (mean.res.DR/curr.sim)[, , 1, 1]
        mean.res/curr.sim
        
        # Per-simulation g-computation estimates, one column per (l0, a0, a1) cell.
        wide <- GComp.res[1:curr.sim,]
        colnames(wide) <- c("000","001","010","011",
                            "100","101","110","111")
        wide <- data.frame(wide)
        library(data.table)
        long <- melt(setDT(wide),measure.vars = 1:8, variable.name = "l0a0a1")
        # cond.g.null.aov = aov(value~l0a0a1, data=long[1:(4*curr.sim),])
        # summary(cond.g.null.aov)

        # `long` stacks the 8 columns in order, so its first 4*curr.sim rows
        # are the l0 = 0 cells and the remaining rows the l0 = 1 cells, each
        # cell being a run of curr.sim simulations.
        # p.l0 holds, per simulation, the sum of cnt over rows with l0.2 == 1;
        # dividing by the sample size n (previously hard-coded as 5000) gives
        # the weight of the l0 = 1 cells. The weights are repeated once per
        # (a0, a1) combination instead of relying on implicit recycling.
        w.l0 <- rep(as.vector(p.l0)[1:curr.sim] / n, times = 4)
        value.l0.0 <- long$value[1:(4*curr.sim)]
        value.l0.1 <- long$value[(1+4*curr.sim):(8*curr.sim)]
        marginal.data <- data.frame(value = (1 - w.l0) * value.l0.0 + w.l0 * value.l0.1,
                                    a0a1 = rep(c("00","01","10","11"), each=curr.sim))
        marginal.g.null.aov <- aov(value~a0a1, data=marginal.data) # because L0 is Bern(1/2)
        summary(marginal.g.null.aov)
        print(paste("This sample size takes",round((proc.time() - ptm1)[3],2),
                    "s"))
        
    # } # End for n.index in 1:length(n.cand))
    
    # Contrast at the true parameters, computed on one large generated data
    # set and stored next to the other true values. The row names of
    # gamma.true are set before and again after the data generation, as in
    # the original statement order.
    gamma.row.names <- c("a0","a1")
    rownames(gamma.true) <- gamma.row.names
    datagen <- DataGen(100000, SEED=0)
    MyAttach(datagen)
    params.true[["contrast.true"]] <- GetContrast(data, alpha.true, beta.true)
    rownames(gamma.true) <- gamma.row.names

    # Total run time since ptm0.
    total.secs <- round((proc.time() - ptm0)[3],2)
    print(paste("It takes",total.secs,"s"))

    
    # Persist the final results of this run (file name carries SEED.index).
    save(results, GComp.res, p.l0, params.true, file=paste0("./RDataFiles/15Jul2021-n-1000-theta-0-g-comp-y-misgop",SEED.index,".RData"))

    # Reload the saved results for seed block 1 from the local project
    # directory. The directory only exists on the desktop machine, so skip
    # this step elsewhere instead of aborting in setwd() after the results
    # have already been written.
    cluster.path <- "~/Desktop/causal res/SNMM_Local/3. Data and Programming/cluster/080820"
    if (dir.exists(cluster.path)) {
        setwd(cluster.path)
        load(paste0("./RDataFiles/15Jul2021-n-1000-theta-0-g-comp-y-misgop",1,".RData"))
    }
    
